# encoding: utf-8

from dataclasses import dataclass
from transformers import PretrainedConfig


class ModelConfig(PretrainedConfig):
    """Hyperparameter configuration for the classification model.

    Parameters are stored as *instance* attributes inside ``__init__`` so that
    ``PretrainedConfig`` machinery (``to_dict`` / ``save_pretrained`` /
    ``from_pretrained``) serializes and restores them correctly. Declaring
    them as class attributes (the previous form) makes them invisible to that
    machinery and silently drops any constructor overrides on save/load.
    """

    # HF registry key for this architecture; conventionally a class attribute.
    model_type = "classification"

    def __init__(
        self,
        dim: int = 128,            # model (embedding) dimension
        n_layers: int = 8,         # number of transformer layers
        n_heads: int = 32,         # attention heads
        n_kv_heads: int = 8,       # key/value heads (grouped-query attention)
        vocab_size: int = 6400,
        hidden_dim=None,           # FFN hidden size; None -> derived by the model
        multiple_of: int = 64,     # round FFN hidden size up to a multiple of this
        norm_eps: float = 1e-5,    # epsilon for layer/RMS norm
        max_seq_len: int = 128,
        dropout: float = 0.0,
        flash_attn: bool = True,   # use flash attention when available
        **kwargs,
    ):
        self.dim = dim
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        self.multiple_of = multiple_of
        self.norm_eps = norm_eps
        self.max_seq_len = max_seq_len
        self.dropout = dropout
        self.flash_attn = flash_attn
        # Forward remaining kwargs (e.g. pad_token_id) to PretrainedConfig.
        super().__init__(**kwargs)




if __name__ == '__main__':
    # Quick smoke check: build a default config and list its attributes.
    config = ModelConfig()
    print(dir(config))
